{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "outputs": [],
   "source": [
    "FILE = \"fashion-mnist_train.csv\"\n",
    "with open(f\"data/{FILE}\") as f:\n",
    "    examples = f.read().strip().split(\"\\n\")[1:]\n",
    "\n",
    "train_data, valid_data, test_data = [], [], []\n",
    "valid_cutoff = .8\n",
    "test_cutoff = .9\n",
    "\n",
    "for example in examples:\n",
    "    label, pixels = example.split(\",\", 1)\n",
    "    label = int(label)\n",
    "    if label > 4:\n",
    "        continue\n",
    "    pixels = [int(p) for p in pixels.split(\",\")]\n",
    "    data = np.asarray(pixels, dtype=\"int32\")\n",
    "    data = data.reshape((1,28,28))\n",
    "    # Scale values from -1 to 1\n",
    "    data = ((data / 255) - .5) / .5\n",
    "    # Move channel axis to first axis\n",
    "    #data = np.moveaxis(data, -1, 0)\n",
    "    target = np.zeros((5))\n",
    "    target[label] = 1\n",
    "    row = (data, target, )\n",
    "\n",
    "    split = random.random()\n",
    "    if split > test_cutoff:\n",
    "        test_data.append(row)\n",
    "    elif split > valid_cutoff:\n",
    "        valid_data.append(row)\n",
    "    else:\n",
    "        train_data.append(row)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "device = torch.device(\"cpu\")\n",
    "\n",
    "from torchvision.io import read_image\n",
    "from torchvision import transforms as T\n",
    "from torch.utils.data import Dataset\n",
    "import math\n",
    "\n",
    "class ImageDataset(Dataset):\n",
    "    def __init__(self, dataset):\n",
    "        self.dataset = dataset\n",
    "        self.normalize = T.Compose([\n",
    "            T.ConvertImageDtype(torch.float)\n",
    "        ])\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def classes(self):\n",
    "        return 5\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        data, target = self.dataset[idx]\n",
    "        data = torch.from_numpy(data)\n",
    "        return self.normalize(data), target"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from torch import nn\n",
    "\n",
    "BATCH_SIZE = 1\n",
    "EPOCHS = 20\n",
    "\n",
    "train_dataset = ImageDataset(train_data)\n",
    "valid_dataset = ImageDataset(valid_data)\n",
    "test_dataset = ImageDataset(test_data)\n",
    "\n",
    "train = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "valid = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "test = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "outputs": [],
   "source": [
    "class NeuralNetwork(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(NeuralNetwork, self).__init__()\n",
    "\n",
    "        self.cnn = nn.Sequential(\n",
    "            nn.Conv2d(1, 1, 3),\n",
    "            nn.ReLU(True),\n",
    "        )\n",
    "\n",
    "        self.dense = nn.Sequential(\n",
    "            nn.Linear(26 ** 2, 5)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.cnn(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.dense(x)\n",
    "        return x"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 train loss: 0.13648992776870728\n",
      "Valid accuracy: 0.8597640991210938\n",
      "Epoch 1 train loss: 0.24475865066051483\n",
      "Valid accuracy: 0.8725426197052002\n",
      "Epoch 2 train loss: 0.13204717636108398\n",
      "Valid accuracy: 0.8633682727813721\n",
      "Epoch 3 train loss: 0.0027682576328516006\n",
      "Valid accuracy: 0.8574705123901367\n",
      "Epoch 4 train loss: 0.06656605750322342\n",
      "Valid accuracy: 0.8797509670257568\n",
      "Epoch 5 train loss: 0.711868941783905\n",
      "Valid accuracy: 0.8673001527786255\n",
      "Epoch 6 train loss: 1.0944929122924805\n",
      "Valid accuracy: 0.8617300391197205\n",
      "Epoch 7 train loss: -0.0\n",
      "Valid accuracy: 0.8620576858520508\n",
      "Epoch 8 train loss: 0.05632047355175018\n",
      "Valid accuracy: 0.8650065660476685\n",
      "Epoch 9 train loss: 0.0006551980040967464\n",
      "Valid accuracy: 0.8705766797065735\n",
      "Epoch 10 train loss: 0.021929049864411354\n",
      "Valid accuracy: 0.8764744400978088\n",
      "Epoch 11 train loss: 0.0053314645774662495\n",
      "Valid accuracy: 0.8718872666358948\n",
      "Epoch 12 train loss: 0.07228444516658783\n",
      "Valid accuracy: 0.8751638531684875\n",
      "Epoch 13 train loss: 0.003100709058344364\n",
      "Valid accuracy: 0.8751638531684875\n",
      "Epoch 14 train loss: 0.044329818338155746\n",
      "Valid accuracy: 0.8705766797065735\n",
      "Epoch 15 train loss: 3.099436753473128e-06\n",
      "Valid accuracy: 0.8689383864402771\n",
      "Epoch 16 train loss: 0.00015162272029556334\n",
      "Valid accuracy: 0.8784403800964355\n",
      "Epoch 17 train loss: 0.013874982483685017\n",
      "Valid accuracy: 0.8728702664375305\n",
      "Epoch 18 train loss: 1.2067996263504028\n",
      "Valid accuracy: 0.8735255599021912\n",
      "Epoch 19 train loss: 0.020493200048804283\n",
      "Valid accuracy: 0.8745085000991821\n"
     ]
    }
   ],
   "source": [
    "model = NeuralNetwork().to(device)\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=5e-3)\n",
    "\n",
    "size = len(train.dataset)\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    for batch, (images, targets) in enumerate(train):\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        images = images.to(device)\n",
    "        pred = model(images.float())\n",
    "\n",
    "        loss = loss_fn(pred, targets)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    loss = loss.item()\n",
    "    print(f\"Epoch {epoch} train loss: {loss}\")\n",
    "\n",
    "    match = list()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch, (images, targets) in enumerate(valid):\n",
    "            images = images.to(device)\n",
    "            outputs = model(images.float())\n",
    "\n",
    "            _, p = torch.max(outputs.data, 1)\n",
    "\n",
    "            p = p.cpu().numpy()\n",
    "            match.append(p[0] == np.argmax(targets, axis=1)[0])\n",
    "\n",
    "    print(f\"Valid accuracy: {sum(match) / len(match)}\")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
